import torch
from utils.data_class import PVTrainDataSet,PVTestDataSet
import numpy as np


class Sim1d_noX:
    """Synthetic 1-D treatment simulation with latent variables and proxies.

    Two latent variables ``U1, U2`` (never returned) drive:

    * a scalar treatment ``A = U2 + N(0, beta^2)``,
    * a 2-D treatment proxy ``Z = (U1 + N(0, sigma^2), U2 + Unif(-1, 1))``,
    * a 2-D outcome proxy ``W = (U1 + Unif(-1, 1), U2 + N(0, sigma^2))``,
    * a scalar outcome ``Y = 3 cos(2 (0.3 U1 + 0.3 U2 + 0.2) + 1.5 A) + N(0, 1)``.

    No backdoor covariates are generated: every returned dataset has
    ``backdoor=None``.
    """

    def __init__(self, seeds=43, size=1000):
        """Store the sampling configuration used by :meth:`generatate_sim`.

        Args:
            seeds: Seed passed to ``np.random.seed`` before sampling.
            size: Number of samples to draw.
        """
        self.seeds = seeds
        self.size = size

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _sample_latent(size, rng):
        """Draw the latent pair (U1, U2) from ``rng``.

        U2 ~ Uniform(-1, 2); U1 ~ Uniform(0, 1), shifted down by 1 whenever
        U2 lies in [0, 1]. ``rng`` is either the ``np.random`` module (global
        state) or a ``np.random.RandomState``; both expose ``uniform``.
        U2 is drawn first; callers rely on this order for reproducibility.
        """
        U2 = rng.uniform(-1, 2, size=size)
        U1 = rng.uniform(0, 1, size=size) - ((U2 >= 0) & (U2 <= 1)).astype(int)
        return U1, U2

    @staticmethod
    def _outcome_mean(U1, U2, a):
        """Noise-free outcome 3*cos(2*(0.3*U1 + 0.3*U2 + 0.2) + 1.5*a)."""
        return 3 * np.cos(2 * (0.3 * U1 + 0.3 * U2 + 0.2) + 1.5 * a)

    @staticmethod
    def _simulate(size, beta, sigma, W_miss=False, Z_miss=False):
        """Draw one sample of (A, Z, W, Y) from the *global* NumPy RNG.

        The draw order (U2, U1, Z2, Z1, W1, W2, A, outcome noise) is fixed so
        that a given ``np.random.seed`` reproduces identical data.

        Args:
            size: Number of samples.
            beta: Std. dev. of the treatment noise.
            sigma: Std. dev. of the Gaussian noise in Z1 and W2.
            W_miss: If true, replace W by ``sqrt(|W|) + 1`` (presumably to
                emulate a misspecified outcome proxy).
            Z_miss: If true, replace Z by ``sqrt(|Z|) + 1``.

        Returns:
            Tuple ``(A, Z, W, Y)`` with shapes ``(size,)``, ``(size, 2)``,
            ``(size, 2)`` and ``(size,)``.
        """
        U1, U2 = Sim1d_noX._sample_latent(size, np.random)
        Z2 = U2 + np.random.uniform(-1, 1, size=size)
        Z1 = U1 + np.random.normal(0, sigma, size=size)
        W1 = U1 + np.random.uniform(-1, 1, size=size)
        W2 = U2 + np.random.normal(0, sigma, size=size)
        A = U2 + np.random.normal(0, beta, size=size)
        Y = Sim1d_noX._outcome_mean(U1, U2, A) + np.random.normal(0, 1, size=size)

        # The nonlinear transforms are applied after all draws, so they never
        # shift the RNG stream.
        Z = np.c_[Z1, Z2]
        if Z_miss:
            Z = np.sqrt(np.abs(Z)) + 1
        W = np.c_[W1, W2]
        if W_miss:
            W = np.sqrt(np.abs(W)) + 1
        return A, Z, W, Y

    @staticmethod
    def _pack(dataset_cls, A, Z, W, Y, totensor):
        """Wrap the arrays in ``dataset_cls``; optionally as float32 tensors.

        ``A`` and ``Y`` are reshaped from ``(n,)`` to ``(n, 1)``.
        """
        fields = {
            "treatment": A[:, np.newaxis],
            "treatment_proxy": Z,
            "outcome_proxy": W,
            "outcome": Y[:, np.newaxis],
        }
        if totensor:
            fields = {
                key: torch.tensor(value, dtype=torch.float32)
                for key, value in fields.items()
            }
        return dataset_cls(backdoor=None, **fields)

    # --------------------------------------------------------------- public API

    def generatate_sim(self, totensor: bool = False, beta=1, sigma=1,
                       W_miss=False, Z_miss=False) -> "PVTrainDataSet":
        """Generate a training set of ``self.size`` samples.

        Reseeds the global NumPy RNG with ``self.seeds`` first, so repeated
        calls return identical data.

        Args:
            totensor: Return ``torch.float32`` tensors instead of NumPy arrays.
            beta: Std. dev. of the treatment noise.
            sigma: Std. dev. of the Gaussian proxy noise.
            W_miss: Apply ``sqrt(|.|) + 1`` to the outcome proxy.
            Z_miss: Apply ``sqrt(|.|) + 1`` to the treatment proxy.

        Returns:
            A ``PVTrainDataSet`` with ``backdoor=None``.
        """
        np.random.seed(self.seeds)
        A, Z, W, Y = self._simulate(self.size, beta, sigma,
                                    W_miss=W_miss, Z_miss=Z_miss)
        return self._pack(PVTrainDataSet, A, Z, W, Y, totensor)

    # Correctly spelled alias; the original (misspelled) name is kept for
    # backward compatibility with existing callers.
    generate_sim = generatate_sim

    @staticmethod
    def generate_test(size, seed=43, totensor=False, beta=1, sigma=1,
                      W_miss=False, Z_miss=False) -> "PVTestDataSet":
        """Generate a test set of ``size`` samples from the same process.

        Reseeds the global NumPy RNG with ``seed`` first.

        Args:
            size: Number of samples.
            seed: Seed passed to ``np.random.seed``.
            totensor: Return ``torch.float32`` tensors instead of NumPy arrays.
            beta: Std. dev. of the treatment noise.
            sigma: Std. dev. of the Gaussian proxy noise.
            W_miss: Apply ``sqrt(|.|) + 1`` to the outcome proxy (default off).
            Z_miss: Apply ``sqrt(|.|) + 1`` to the treatment proxy (default off).

        Returns:
            A ``PVTestDataSet`` with ``backdoor=None``.
        """
        np.random.seed(seed)
        A, Z, W, Y = Sim1d_noX._simulate(size, beta, sigma,
                                         W_miss=W_miss, Z_miss=Z_miss)
        return Sim1d_noX._pack(PVTestDataSet, A, Z, W, Y, totensor)

    @staticmethod
    def generate_test_effect(a, b, c, seed=None, n_samples=10000):
        """Monte-Carlo estimate of the mean outcome curve over a treatment grid.

        For each ``a_val`` in ``np.linspace(a, b, c)``, this averages the
        noise-free outcome ``3 cos(2 (0.3 U1 + 0.3 U2 + 0.2) + 1.5 a_val)``
        over ``n_samples`` draws of the latent (U1, U2).

        Args:
            a: Grid start.
            b: Grid end (inclusive).
            c: Number of grid points.
            seed: If ``None`` (default), draw from the global NumPy RNG. This
                matches the historical behaviour and is not reproducible
                unless the caller seeds it. Otherwise draw from a private
                ``RandomState(seed)`` and leave global RNG state untouched.
            n_samples: Number of latent draws used for the average.

        Returns:
            Tuple ``(A, treatment)`` of 1-D arrays of length ``c``.
        """
        A = np.linspace(a, b, c)
        rng = np.random if seed is None else np.random.RandomState(seed)
        U1, U2 = Sim1d_noX._sample_latent(n_samples, rng)
        treatment = np.array(
            [np.mean(Sim1d_noX._outcome_mean(U1, U2, a_val)) for a_val in A]
        )
        return A, treatment
        